package cn.edu.dgut.sw.lab.security.oauth2.client.userinfo;

import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequestEntityConverter;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

import java.util.*;

/**
 * {@link DefaultOAuth2UserService} customised for the "qq", "weixin" and "dgut" client registrations.
 * <p>
 * Compared with the default service it:
 * <ul>
 * <li>adds an {@code openid} parameter (and, for QQ, {@code oauth_consumer_key}) to the user-info
 * request of those registrations;</li>
 * <li>also parses user-info responses served as {@code text/plain} or {@code text/html} as JSON;</li>
 * <li>for QQ, copies {@code openid} into the user attributes and uses it as the principal name,
 * because QQ nicknames are not unique.</li>
 * </ul>
 *
 * @author Sai
 * @see OAuth2UserService
 * @see OAuth2UserRequest
 * @see OAuth2User
 * @see DefaultOAuth2User
 * @since 1.0
 * 2018-8-23
 * 2018-9-29
 * 2018-11-30
 */
public final class SaiOAuth2UserService extends DefaultOAuth2UserService {

    /** Registration id of the QQ client. */
    private static final String QQ_REGISTRATION_ID = "qq";

    /** Name of the additional request parameter / user attribute that carries the provider's user id. */
    private static final String OPENID = "openid";

    /**
     * Creates the service with the custom request-entity converter and a {@link RestTemplate} that can
     * also read JSON from {@code text/plain} and {@code text/html} responses.
     */
    public SaiOAuth2UserService() {
        setRequestEntityConverter(new SaiOAuth2UserRequestEntityConverter());

        // NOTE(review): DefaultOAuth2UserService documents that the RestOperations passed to
        // setRestOperations should use an OAuth2ErrorResponseErrorHandler; this RestTemplate does not
        // install one. Consider restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()).
        RestTemplate restTemplate = new RestTemplate();
        // Appended after RestTemplate's default converters; it only claims text/plain and text/html.
        restTemplate.getMessageConverters().add(new SaiMappingJackson2HttpMessageConverter());
        setRestOperations(restTemplate);
    }

    /**
     * Loads the user through {@link DefaultOAuth2UserService#loadUser}. For the QQ registration, the
     * result is rebuilt so that {@code openid} is both an attribute and the name attribute key.
     *
     * @param userRequest the user request; for QQ its additional parameters are expected to contain
     *                    {@code openid}
     * @return the loaded user, or for QQ a copy whose name attribute is {@code openid}
     * @throws OAuth2AuthenticationException if loading the user from the user-info endpoint fails
     */
    @Override
    public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
        OAuth2User oauth2User = super.loadUser(userRequest);

        if (!QQ_REGISTRATION_ID.equals(userRequest.getClientRegistration().getRegistrationId())) {
            return oauth2User;
        }

        // QQ nicknames are not unique and cannot serve as the principal name, so build a new user whose
        // attributes include openid and which uses openid as its name attribute key.
        // NOTE(review): openid is taken as-is from the additional parameters; confirm upstream always
        // supplies it, otherwise the resulting user has no usable name.
        Map<String, Object> attributes = new LinkedHashMap<>(oauth2User.getAttributes());
        attributes.put(OPENID, userRequest.getAdditionalParameters().get(OPENID));
        return new DefaultOAuth2User(new LinkedHashSet<>(oauth2User.getAuthorities()), attributes, OPENID);
    }

    /**
     * {@link OAuth2UserRequestEntityConverter} that adds provider-specific parameters to the user-info
     * request.
     * <p>
     * When calling the user-info endpoint, QQ, weixin and dgut require an extra {@code openid} parameter;
     * QQ additionally requires {@code oauth_consumer_key} (the client id).
     *
     * @author Sai
     * @since 1.0.5
     */
    private static final class SaiOAuth2UserRequestEntityConverter extends OAuth2UserRequestEntityConverter {

        /** Registrations whose user-info endpoint needs the extra {@code openid} parameter. */
        private static final Set<String> OPENID_REGISTRATION_IDS =
                Collections.unmodifiableSet(new HashSet<>(Arrays.asList("weixin", "dgut", QQ_REGISTRATION_ID)));

        /**
         * Builds the default user-info request. For the registrations listed above, it re-creates the
         * request with the extra form parameters appended.
         *
         * @param userRequest the user request
         * @return the request entity to send to the user-info endpoint
         * @throws IllegalStateException if extra parameters are needed but the default request has no
         *                               form-encoded body to add them to
         */
        @Override
        public RequestEntity<?> convert(OAuth2UserRequest userRequest) {
            String registrationId = userRequest.getClientRegistration().getRegistrationId();
            RequestEntity<?> request = super.convert(userRequest);

            if (!OPENID_REGISTRATION_IDS.contains(registrationId)) {
                return request;
            }

            MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
            formParameters.addAll(formBodyOf(request, registrationId));

            if (QQ_REGISTRATION_ID.equals(registrationId)) {
                formParameters.add("oauth_consumer_key", userRequest.getClientRegistration().getClientId());
            }
            // Objects.toString keeps a missing openid as null instead of failing on a non-String value.
            formParameters.add(OPENID, Objects.toString(userRequest.getAdditionalParameters().get(OPENID), null));

            return new RequestEntity<>(formParameters, request.getHeaders(), request.getMethod(), request.getUrl());
        }

        /**
         * Returns the form-encoded body of the default user-info request.
         *
         * @throws IllegalStateException if the request is missing or its body is not a form map
         */
        @SuppressWarnings("unchecked") // the parent converter builds form bodies as MultiValueMap<String, String>
        private static MultiValueMap<String, String> formBodyOf(RequestEntity<?> request, String registrationId) {
            Object body = request == null ? null : request.getBody();
            if (!(body instanceof MultiValueMap)) {
                throw new IllegalStateException("User-info request for registration '" + registrationId
                        + "' has no form-encoded body to add the openid parameter to; "
                        + "configure its user-info authentication method as 'form'");
            }
            return (MultiValueMap<String, String>) body;
        }
    }

    /**
     * Jackson converter that also reads {@code text/plain} and {@code text/html} responses as JSON.
     * <p>
     * QQ, weixin and dgut need this; the stock {@link MappingJackson2HttpMessageConverter} only handles
     * JSON content types.
     *
     * @author Sai
     * @since 1.0.4
     */
    private static final class SaiMappingJackson2HttpMessageConverter extends MappingJackson2HttpMessageConverter {

        SaiMappingJackson2HttpMessageConverter() {
            setSupportedMediaTypes(Arrays.asList(MediaType.TEXT_PLAIN, MediaType.TEXT_HTML));
        }
    }
}
